#############################################################
# Author: Mike Burnham, mlb6496@psu.edu
# R: 4.3.1
# OS: Windows 10
#
# Notes: This script reproduces all of the numbers in section
# 4 as well as figure 3
##############################################################

library(dplyr)
library(ggplot2) 
library(mltools)
library(dotwhisker)
library(MASS)
# install broom if necessary
#install.packages('broom')

# Mirror all printed output to sect_4.log (appending across runs) while still
# showing it on the console. A later sink() call closes the diversion; if the
# script errors before that point, call sink() manually to restore the console.
sink("sect_4.log", append=TRUE, split=TRUE)
#########################
## Import and clean data
#########################
# Import datasets of classified tweets
# original classifications
og <- read.csv('./covid_classified_tweet_ids.csv')
# first zero shot
tweets <- read.csv("./threatmin1.csv")
# second zero shot
tweets2 <- read.csv("./threatmin2.csv")
# training tweets
training <- read.csv("./covid_training_sample.csv")

# import data on users
users <- read.csv('./users.csv')
# Drop users with a missing value in ANY column, then keep only users with at
# least one tweet and robust ideology estimates (Rhat <= 1.1; presumably the
# convergence diagnostic of the ideology estimate -- confirm).
users <- users[complete.cases(users),]
# Compare plain column vectors rather than one-column data frames, which would
# build a logical matrix and use that as the row index.
users <- users[users$tweets >= 1 & users$Rhat <= 1.1,]

# group by users and sum number of threat minimizing tweets
zs1 <- tweets[c('id10', 'threatmin')] %>% 
  group_by(id10) %>%
  summarise(threatmin1 = sum(threatmin))

zs2 <- tweets2[c('id10', 'threatmin')] %>% 
  group_by(id10) %>%
  summarise(threatmin2 = sum(threatmin))

# Join zero-shot labels to user data. Users with no rows in a zero-shot file
# get NA (not 0) for that count.
# NOTE(review): those NA rows are dropped from the zero-shot models by the
# default na.action, so their samples can differ from the supervised model's.
# Confirm whether NA should instead be treated as 0.
users <- left_join(users, zs1, by = 'id10')
users <- left_join(users, zs2, by = 'id10')

# Create new user variables for analysis: population shares of demographic
# groups, Republican vote share, and per-capita case and death rates.
users$pct_black <- users$nhb/users$population
users$pct_asn <- users$asian/users$population
users$pct_hisp <- users$hisp/users$population
users$rep_share <- users$rep_votes/users$totalvotes
users$infect_rate <- users$cases/users$population
users$death_rate <- users$deaths/users$population

########################################
## Mean values reported in section 5.2
########################################
# NOTE(review): the file header says this script reproduces section 4;
# confirm which section number is correct.
#
# Each value is wrapped in print() so that it reaches the console and
# sect_4.log both under Rscript and under source(). source() does not
# auto-print top-level values unless echo = TRUE.
print("Mean ideology of a tweet's author:")
print(mean(tweets$ideology))

# Mean author ideology among tweets whose original label (non_comp) differs
# from the zero-shot label (threatmin)
print("Mean ideology where original labels and zs1 classifier disagree:")
print(mean(tweets[tweets$non_comp != tweets$threatmin,'ideology']))

print("Mean ideology where original labels and zs2 classifier disagree:")
print(mean(tweets2[tweets2$non_comp != tweets2$threatmin,'ideology']))


################################################
## Test set performance reported in section 5.2
################################################
# Values are wrapped in print() so they reach the console and sect_4.log
# whether the script is run with Rscript or via source().

# Matthews correlation coefficient between each pair of label sets
print("Test set MCC Between original labels and hypothesis set 1:")
print(mcc(training$non_comp, training$threatmin))
print("Test set MCC Between original labels and hypothesis set 2:")
print(mcc(training$non_comp, training$threatmin2))
print("Test set MCC Between hypothesis set 1 and hypothesis set 2:")
print(mcc(training$threatmin, training$threatmin2))

# Accuracy: the share of test tweets where the hypothesis-set label matches
# the original label. mean() of the logical vector divides by the actual
# number of test tweets rather than a hard-coded 300.
print("Hypothesis set 1 accuracy:")
print(mean(training$non_comp == training$threatmin))
print("Hypothesis set 2 accuracy:")
print(mean(training$non_comp == training$threatmin2))

# Stop mirroring output to the log file opened at the top of the script
sink()
######################
# NB models death rate
######################
# Rescale death rate from deaths per resident to deaths per 100 residents.
# This matches the 'Deaths Per 100' label in Figure 3. It modifies
# users$death_rate in place, so re-running this line on its own rescales again.
users$death_rate <- users$death_rate*100

# Right-hand side shared by all three negative binomial models. Defining it
# once keeps the supervised and zero-shot specifications identical.
death_model_terms <- c(
  # Ideology and death rate: `*` expands to both main effects plus the
  # ideology:death_rate interaction
  "ideology * death_rate",

  # Offset: log of the `tweets` column of the model data (the per-user tweet
  # count, not the global `tweets` data frame)
  "offset(log(tweets))",

  # Age, education, density
  "p_65_up",
  "pbach_grad",
  "POPPCT_URBAN",

  # Race
  "pct_hisp",
  "pct_black",
  "pct_asn",

  # Income
  "log(med_inc)",

  # Politics
  "rep_share",

  # Fixed Effects
  "state"
)

# Fit the shared negative binomial specification, using the column named by
# `outcome` (a string) as the count response.
fit_death_nb <- function(outcome, data = users) {
  glm.nb(reformulate(death_model_terms, response = outcome), data = data)
}

# Supervised labels, then the two zero-shot hypothesis sets
fit_death_og <- fit_death_nb("non_compliant")
fit_death_zs1 <- fit_death_nb("threatmin1")
fit_death_zs2 <- fit_death_nb("threatmin2")


#####################
## Figure 3: DW Plot
#####################

# Coefficients displayed in Figure 3, and the model labels in legend order
fig3_terms <- c('ideology', 'death_rate', 'ideology:death_rate')
fig3_models <- c('Supervised', 'NLI Hypothesis Set 1', 'NLI Hypothesis Set 2')

# Tidy a fitted model, keep only the Figure 3 coefficients, and tag each row
# with `label` so dwplot can tell the models apart.
tidy_fig3 <- function(fit, label) {
  est <- broom::tidy(fit)
  est <- est[est$term %in% fig3_terms,]
  est$model <- label
  est
}

ogdf <- tidy_fig3(fit_death_og, fig3_models[1])
zs1df <- tidy_fig3(fit_death_zs1, fig3_models[2])
zs2df <- tidy_fig3(fit_death_zs2, fig3_models[3])

# Black-and-white dot-whisker plot: models are distinguished by point shape
# and whisker linetype rather than colour.
dw_plot <- dwplot(rbind(ogdf, zs1df, zs2df),
       model_order = fig3_models,
       dot_args = list(aes(shape = model), size = 3, color = "black", fill = "black"),
       whisker_args = list(aes(linetype = model), color = "black")) %>%
  relabel_predictors(c(ideology = 'Ideology',
                       death_rate = 'Deaths Per 100',
                       `ideology:death_rate` = 'Ideology:Deaths')) +
  theme_classic(base_size = 16) +
  xlab('Coefficient Estimate') +
  geom_vline(xintercept = 0, color = 'black', linetype = 2) +
  theme(legend.position = c(.85, 0.12),
        legend.title = element_blank(),
        axis.text = element_text(size = 14, color = 'black'),
        axis.title.x = element_text(size = 16)) +
  scale_shape_discrete(name = 'Model',
                       breaks = fig3_models) +
  scale_fill_manual(values = c("black"), guide = "none") +
  scale_color_manual(values = c("black"), guide = "none") +
  guides(shape = guide_legend('Model'))

ggsave("./Figure_3.png", dw_plot, dpi = 200, width = 8, height = 6, units = "in")